package cn.spark.study.sql;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * JDBC data source example: loads two MySQL tables as DataFrames, joins
 * them, filters by score, and writes the result back to MySQL.
 *
 * @author jun.zhang6
 * @date 2020/11/15
 */
public class JdbcDataSource {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JdbcDataSource");
        JavaSparkContext sc = new JavaSparkContext(conf);
        HiveContext hiveContext = new HiveContext(sc.sc());
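        // Note: HiveContext extends SQLContext; since this example only uses
        // the JDBC data source, a plain SQLContext would work here as well.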

        Map<String, String> options = new HashMap<String, String>();
        options.put("url", "jdbc:mysql://spark1:3306/testdb");
        options.put("dbtable", "student_infos");

        // Load the MySQL table as a DataFrame via the read() method
        DataFrame studentInfosDF = hiveContext.read().format("jdbc").options(options).load();

        options.put("dbtable", "student_scores");
        DataFrame studentScoresDF = hiveContext.read().format("jdbc").options(options).load();

        // Convert both DataFrames to pair RDDs keyed by student name and join them
        JavaPairRDD<String, Tuple2<Integer, Integer>> studentsRDD =
                studentInfosDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {
                    @Override
                    public Tuple2<String, Integer> call(Row row) throws Exception {
                        // (name, age) from student_infos
                        return new Tuple2<String, Integer>(row.getString(0), (int) row.getLong(1));
                    }
                }).join(studentScoresDF.javaRDD().mapToPair(new PairFunction<Row, String, Integer>() {
                    @Override
                    public Tuple2<String, Integer> call(Row row) throws Exception {
                        // (name, score) from student_scores
                        return new Tuple2<String, Integer>(row.getString(0), (int) row.getLong(1));
                    }
                }));
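        // The same join could also be expressed in SQL via temp tables
        // (sketch, Spark 1.x API; the column names name/age/score are
        // assumptions about the MySQL schema):
        //   studentInfosDF.registerTempTable("student_infos");
        //   studentScoresDF.registerTempTable("student_scores");
        //   DataFrame joined = hiveContext.sql(
        //           "SELECT i.name, i.age, s.score FROM student_infos i "
        //           + "JOIN student_scores s ON i.name = s.name");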

        JavaRDD<Row> studentRowsRDD = studentsRDD.map(new Function<Tuple2<String, Tuple2<Integer, Integer>>, Row>() {
            @Override
            public Row call(Tuple2<String, Tuple2<Integer, Integer>> tuple) throws Exception {
                return RowFactory.create(tuple._1, tuple._2._1, tuple._2._2);
            }
        });

        // Keep only students whose score is greater than 80
        JavaRDD<Row> filteredStudentRowsRDD = studentRowsRDD.filter(new Function<Row, Boolean>() {
            @Override
            public Boolean call(Row row) throws Exception {
                return row.getInt(2) > 80;
            }
        });

        // Convert back to a DataFrame with an explicit schema
        List<StructField> structFields = new ArrayList<StructField>();
        structFields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("age", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("score", DataTypes.IntegerType, true));
        StructType structType = DataTypes.createStructType(structFields);

        DataFrame studentsDF = hiveContext.createDataFrame(filteredStudentRowsRDD, structType);
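        // Once a schema is attached, the score filter above could equally be
        // expressed on the DataFrame itself (sketch):
        //   DataFrame filtered = studentsDF.filter(studentsDF.col("score").gt(80));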

        Row[] rows = studentsDF.collect();

        for (Row row : rows) {
            System.out.println(row);
        }

        // Save the DataFrame's data back to MySQL
        // This pattern is common in practice; the sink could just as well be
        // MySQL, HBase, or a Redis cache
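        // Note: foreach() below opens one JDBC connection per row, which is
        // expensive; for real workloads foreachPartition() is preferred, so
        // that a single connection is shared by all rows in a partition.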
        studentsDF.javaRDD().foreach(new VoidFunction<Row>() {
            @Override
            public void call(Row row) throws Exception {
                String sql = "insert into good_student_infos values("
                        + "'" + row.getString(0) + "',"
                        + Integer.valueOf(String.valueOf(row.get(1))) + ","
                        + Integer.valueOf(String.valueOf(row.get(2))) + ")";

                Class.forName("com.mysql.jdbc.Driver");

                Connection conn = null;
                Statement stmt = null;
                try {
                    conn = DriverManager.getConnection(
                            "jdbc:mysql://spark1:3306/testdb", "", "");
                    stmt = conn.createStatement();
                    stmt.executeUpdate(sql);
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    if (stmt != null) {
                        stmt.close();
                    }
                    if (conn != null) {
                        conn.close();
                    }
                }
            }
        });
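        // Alternatively, Spark 1.4+ can write the DataFrame back over JDBC
        // directly (sketch; the default save mode fails if the table exists):
        //   java.util.Properties props = new java.util.Properties();
        //   studentsDF.write().jdbc("jdbc:mysql://spark1:3306/testdb",
        //           "good_student_infos", props);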

        sc.close();
    }
}
